package com.icesoft.core.web.suppose.csrf;

import com.icesoft.core.common.helper.Resp;
import com.icesoft.core.web.helper.ResponseUtils;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils;
import org.apache.commons.lang3.math.NumberUtils;
import org.springframework.cache.jcache.JCacheCacheManager;
import org.springframework.stereotype.Component;

import javax.cache.Cache;
import javax.cache.CacheManager;
import javax.cache.configuration.MutableConfiguration;
import javax.cache.expiry.CreatedExpiryPolicy;
import javax.cache.expiry.Duration;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import java.time.Instant;
import java.time.LocalDateTime;
import java.time.ZoneId;
import java.time.format.DateTimeFormatter;
import java.util.concurrent.TimeUnit;

@Component
@Slf4j
public class TimeUnitNonceStrFilter {

    private static final long timeUnitExpired = 5 * 60 * 1000;

    private static final String NONCE_STR_NAME = "nonceStr";
    private static final String TIME_UNIT_NAME = "timeUnit";
    private Cache<String, Boolean> cache;
    private static final String CACHE_NAME = "icesoft.nonceStr.5m.key";

    public TimeUnitNonceStrFilter(JCacheCacheManager jCacheCacheManager) {
        MutableConfiguration defaultCacheConfiguration = new MutableConfiguration<>();
        defaultCacheConfiguration
                .setExpiryPolicyFactory(CreatedExpiryPolicy.factoryOf(new Duration(TimeUnit.MILLISECONDS, timeUnitExpired)));
        cache = jCacheCacheManager.getCacheManager().createCache(CACHE_NAME, defaultCacheConfiguration);
    }

    public boolean filter(HttpServletRequest request, HttpServletResponse response) {
        String nonceStr = getNonceStr(request);
        String timeUnit = getTimeUnit(request);
        if (StringUtils.isBlank(nonceStr) || StringUtils.isBlank(timeUnit)) {
            ResponseUtils.writeJson(response, Resp.error("nonceStr和timeUnit不能为空"));
            return false;
        }
        if (nonceStr.length() > 128) {
            ResponseUtils.writeJson(response, Resp.error("nonceStr长度不能超过128位"));
            return false;
        }
        if (cache.get(nonceStr) != null) {
            ResponseUtils.writeJson(response, Resp.error("重复的nonceStr"));
            return false;
        }
        cache.put(nonceStr, true);
        if (!NumberUtils.isDigits(timeUnit)) {
            log.error("timeUnit错误：" + timeUnit);
            ResponseUtils.writeJson(response, Resp.error("timeUnit错误"));
            return false;
        }
        long timestamp = Long.parseLong(timeUnit);
        int length = timeUnit.length();
        if (length != 13 && length != 10) {
            log.error("timeUnit长度错误：" + timeUnit);
            ResponseUtils.writeJson(response, Resp.error("timeUnit长度错误"));
            return false;
        } else if (timeUnit.length() == 10) {
            timestamp = timestamp * 1000;
        }

        if (Math.abs(System.currentTimeMillis() - timestamp) > timeUnitExpired) {
            String msg = "时间错误，请求已失效";
            Instant instant = Instant.ofEpochMilli(timestamp);
            ZoneId zone = ZoneId.systemDefault();
            LocalDateTime userTime = LocalDateTime.ofInstant(instant, zone);
            log.error("timeUnit已失效：{}，{}", timeUnit, userTime.format(DateTimeFormatter.ISO_LOCAL_DATE_TIME));
            ResponseUtils.writeJson(response, Resp.error(msg));
            return false;
        }
        return true;
    }

    public static String getNonceStr(HttpServletRequest request) {
        String nonceStr = request.getParameter(NONCE_STR_NAME);
        if (nonceStr == null) {
            nonceStr = request.getHeader(NONCE_STR_NAME);
        }
        return nonceStr;
    }

    public static String getTimeUnit(HttpServletRequest request) {
        String timeUnit = request.getParameter(TIME_UNIT_NAME);
        if (timeUnit == null) {
            timeUnit = request.getHeader(TIME_UNIT_NAME);
        }
        return timeUnit;
    }

}
